{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "\n",
    "\n",
    "损失函数是模型优化的目标，用于衡量在无数的参数取值中，哪一个是最理想的。损失函数的计算在训练过程的代码中，每一轮训练代码的过程均相同：先根据输入数据正向计算预测输出，再根据预测值和真实值计算损失，最后根据损失反向传播梯度并更新参数。\n",
    "\n",
    "在之前的方案中，我们复用了房价预测模型的损失函数-均方误差。虽然从预测效果来看，使用均方误差使得损失不断下降，模型的预测值逐渐逼近真实值，但模型的最终效果不够理想。究其根本，不同的机器学习任务有各自适宜的损失函数，这里我们详细剖析一下其中的缘由：\n",
    "\n",
    "房价预测是回归任务，而手写数字识别属于分类任务。分类任务中使用均方误差作为损失存在逻辑和效果上的缺欠，比如房价可以是0-9之间的任何浮点数，手写数字识别的数字只可能是0-9之间的10个实数值（标签）。在房价预测的案例中，因为房价本身是一个连续的实数值，以模型输出的数值和真实房价差距作为损失函数（loss）是符合道理的。但对于分类问题，真实结果是标签，而模型输出是实数值，导致两者相减的物理含义缺失。\n",
    "\n",
    "这里我们做一个假设：如果模型能输出10个标签的概率，对应真实标签的概率输出尽可能接近100%，而其他标签的概率输出尽可能接近0%，且所有输出概率之和为1。这是一种更合理的假设！与此对应，真实的标签值可以转变成一个10维度的one-hot向量，在对应数字的位置上为1，其余位置为0，比如标签“6”可以转变成[0,0,0,0,0,1,0,0,0,0]。\n",
    "\n",
    "为了实现上述假设，需要引入Softmax函数。它可以将原始输出转变成对应标签的概率，**Softmax函数公式如下**：\n",
    "\n",
    "$$softmax(x_i) = \\frac {e^{x_i}}{\\sum_{j=0}^N{e^x_j}}, i=0, ..., C-1$$\n",
    "\n",
    "\n",
    "\n",
    "$C$是标签类别个数。\n",
    "从公式的形式可见，每个输出的范围均在0~1之间，且所有输出之和等于1，这是这种变换后可被解释成概率的基本前提。对应到代码上，我们需要在网络定义部分修改输出层：self.fc = FC(name_scope, size=10, act='softmax')，即是对全连接层FC的输出加一个softmax运算。\n",
    "\n",
    "\n",
    "\n",
    "   \n",
    "在该假设下，采用均方误差衡量两个概率的差别不是理论上最优的。人们习惯使用交叉熵误差作为分类问题的损失衡量，因为后者有更合理的物理解释，详见《机器学习的思考故事》。\n",
    "\n",
    "**交叉熵的公式如下：**\n",
    "\n",
    "$$ L = -[\\sum_{k=1}^{n} t_k\\log y_k +(1- t_k)\\log(1-y_k)] $$\n",
    "   \n",
    "其中，$\\log$表示以$e$为底数的自然对数。$y_k$代表模型输出，$t_k$代表各个标签。$t_k$中只有正确解的标签为1，其余均为0（one-hot表示）。因此，交叉熵只计算对应着“正确解”标签的输出的自然对数。比如，假设正确标签的索引是“2”，与之对应的神经网络的输出是0.6，则交叉熵误差是$−\\log 0.6 = 0.51$；若“2”对应的输出是0.1，则交叉熵误差为$−\\log 0.1 = 2.30$。由此可见，交叉熵误差的值是由正确标签所对应的输出结果决定的。\n",
    "\n",
    "自然对数的函数曲线可由如下代码显示。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYQAAAEWCAYAAABmE+CbAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJzt3Xl8VPW9//HXJ4Q1K1lIICEkgbCEVQyg1WoVpNhq0brUpdeqtXb5qX1Y29tFW9vb2tvVLtfb22LVWlu11VZrq3WrWtxAgsoiEAJZIEB2spP9+/tjBhspIYFk5mRm3s/HYx7O5BzO+XwJnvd8v9+zmHMOERGRKK8LEBGRkUGBICIigAJBRET8FAgiIgIoEERExE+BICIigAJBZFDM7CUzu26YtpVvZoVmZoNY909mdu5w7FdkIAoEkeD7NvAjN7iLgL4PfCfA9YgACgSRoDKzycBZwOODWd859wYQb2YFAS1MBAWCRAAz+5KZ/emIn/3czH52gtuLMrPbzKzczKrN7LdmltBn+VX+ZXVm9nUzKzOzFf7F5wBvOufa/etON7N6M1vs/zzFzGrM7AN9dvkS8OETqVXkeCgQJBL8DlhlZokAZhYNXAb81sx+YWYN/bw297O9q/2vs4BcIBa4y7/tfOAXwJXAZCAByOjzZ+cDRYc/OOd2A18GfmdmE4D7gPudcy/1+TPbgYVDaL/IoCgQJOw55w4Aa4FL/D9aBdQ65zY65z7nnEvs57Wgn01eCdzpnCtxzrUAXwUu8wfNxcBfnXOvOOc6gW8AfecKEoHmI+q7G9gFrMcXIrcesb9m/58TCSgFgkSK+4GP+99/HHhgCNuaApT3+VwORANp/mV7Dy9wzrUBdX3WPQjEHWWbdwPzgP9xznUcsSwOaBhCvSKDokCQSPE4sMDM5gHnAb8HMLNfmllLP693+tnWfmBan89ZQDdQBRwAMg8vMLPxQHKfdTcDM/tuzMxigZ8C9wDfNLOkI/Y3B9h0fM0VOX4KBIkI/kncR4EHgTecc3v8P/+Mcy62n9fcfjb3EHCzmeX4D+bfBf7gnOv27+N8M3ufmY0Bvgn0vd7gOWCxmY3r87OfAYXOueuAJ4FfHrG/M4G/D6X9IoOhQJBIcj++Sd2hDBcB3OvfxlqgFGgHbgRwzr3jf/8wvt5CC1ANdPiXVwEvAKsBzGw1vjmNz/q3/QV8gXGlf/kSoMV/+qlIQJkekCORwsyygB1AunOuKUj7jMU3/p/nnCv1/ywfXzgtHejiNP/psvc4554KeLES8RQIEhHMLAq4E4h3zl0b4H2dD/wD31DRj4FlwOJBXpks4plorwsQCTQzi8E34VuOb3gm0FbjG1IyoBC4TGEgoUA9BBERATSpLCIifiE1ZJSSkuKys7O9LkNEJKRs3Lix1jmXOtB6IRUI2dnZFBYWel2GiEhIMbPygdfSkJGIiPgpEEREBFAgiIiInwJBREQABYKIiPh5GghmtsrMisxsl5l9xctaREQinWeBYGajgP8FzgXygcv9N/0SEREPeHkdwlJgl3OuBMDMHsZ3D5htHtYkIuI55xy1LZ2U17VSVtdGeV0rlxZMZWrShIDu18tAyKDPowaBCnx3hXwPM7seuB4gKysrOJWJiATY4YN+WV0rpbWtlNW2Ul7XRlmd778tHd3vrhtlsDhrYlgHwqA459YAawAKCgp0Jz4RCSmNbV2U1Lb4Dvw1rZTWtVFW6wuBvgf96Cgjc+J4slNiWJKdxLTkCWSnxJCdHENG4njGRAd+hN/LQNgHTO3zOdP/MxGRkNLe1cOe+jZKalooqfUd+Ev8B/361s5314syyJg4nuzkGC5anEFOSgzTUmLISY4hc+J4okd5e+Knl4GwAcgzsxx8QXAZcIWH9YiI9Ms5R01zB7tqWiipaaWkppXdNS2U1Law7+AhevuMX0yKG0tOSgwr89PITY0hJyWWnJQJTE2awNjoUd41YgCeBYJzrtvMbgCeAUYB9/qfRysi4pmunl7K69rYXdPCruoWdte0sLumlZLqFpr7DPGMHz2KnJQYFmYmcuFJmUxPjSEnxfeKGzfawxacOE/nEPzPidWzYkUk6A519rx70C+ubmZXte99eV0b3X2+7qfHj2P6pBguOCmD6akxTJ8US25qLJPjxxEVZR62YPiN+EllEZGhaOnopriqmWL/Af/w+30Nhzj8wMhRUca05AnMSI3lg3PTmTEplumpseSmhu63/ROhQBCRsNDW2U1xVQs7/Qf8ospmiqua2d/Y/u46Y6KjyE2J4aSsiVxy8lTy0mLJmxTLtOSYoJzFM9IpEEQkpHR291JS6zvg76xqpqjSFwJ76tveXWdMdBTTU2NZkpPEzLQ48ibFkpcWR1bSBEaF2TDPcFIgiMiI5JxjX8Mhiiqb2eF/FVU2UVLT+u4Yf3SUkZMSw/yMBC4+OZOZaXHMTPN949eB//gpEETEc60d3RRVNbP9QBM7DjSzo7KJHZXNNLf/66yejMTxzE6PY8WcNGalxzErPY7clFgN9QwjBYKIBI1zjv2N7Wzf38S2A01s97/K69veneCNGxvNrPQ4LliUwaz0OGb7D/6RNLnrFQWCiAREd08vu2taeWd/I9v8AbDtQBMNbV3vrjMteQJz0uO58KRM5kyOY87keDInjsdMwz1eUCCIyJC1d/VQVNnM1v2NvLO/iXf2NbKjspmO7l4AxkZHMXtyPOfOm0z+5Djyp8QzKz2e2LE6BI0k+m2IyHE51NnDtgNNbN3XyJZ9jWzd10hxdQs9/one+HHRzJ2SwH+cMo25GfHMnZJAbkqM5/fpkYEpEESkX+1dPeyobGZLRQObK3wB0PfgnxwzhnkZCSyfM4l5UxKYl5GgIZ8QpkAQEcA35l9c3cLmigY2VTSyuaKBospmunr+dfCfn5nAOflpzM9IYH5mAunx43TwDyMKBJEIdPgc/017G9lU0cDbexrYsq+RQ109AMSNi2ZBZgLXvT+XhZkJzM9MZEqCDv7hToEgEgFaO7p9B/69Dby1x/eqbekAfFf1zp0Sz8eWTGXh1AQWZiaSnRwTdjduk4EpEETCjHOOsro23iw/yJt7DvLmngaKKpvevV9/TkoMZ+SlsCgrkUVTE5mdHq+LuwRQIIiEvPauHjZXNLKx/CAb/SFw+CldceOiWTQ1kXPOzmOxPwASJ4zxuGIZqRQIIiGmrqWDQv/Bf0NZPVv3Nb478ZubEsPy2ZNYPG0ii7MmkjcpVkM/MmgKBJERruJgGxvK6nmj1PfaXdMKwJhRUSzITODa03MomJbEydMmkhSjb/9y4hQIIiOIc47S2lbeKK1nvT8A9jUcAnwXfBVkJ3HRyZksyU5ifkYC40aP3OfzSuhRIIh4yDlHSW0r60rqWFdSz/qSOqqbfWf/pMSOYWlOEtefkcvSnCRmpcVp+EcCSoEgEmR769t4bXctr+2u4/Xd/wqASXFjOXV6MstyklmWm0RuSozO+5egUiCIBFhNc4cvAHbV8VpJLXvrfUNAKbG+ADg1N5lTcpPIUQCIxxQIIsOsrbObN0rreaW4lld21bKjshnwzQGcOj2Z607P5X3Tk5kxKVYBICOKAkFkiHp7HdsONPFycS0vF9dQWHaQzp5exkRHUTBtIv+5ahanz0hh7pQEPdZRRjQFgsgJqGvp4OXiWtburGFtcQ21Lb4LweZMjufq07I5fUYKS7KTGD9GZwFJ6FAgiAxCb69jU0UDLxbV8M+iajbva8Q5SIoZw/vzUjgjL5X356UwKX6c16WKnDAFgkg/mtq7WLuzhhe2V/PPnTXUtXZiBoumJnLzipmcOTOV+RkJOhVUwoYCQaSP8rpWnttWxT+2V7OhrJ7uXkfihNGcOTOVs2dP4oy8VCbqamAJUwoEiWiHh4Ke3VbF89uqKK5uAWBmWiyfOiOX5bMnsWhqoh7/KBFBgSARp7O7l3UldTzzTiXPbauiurmDUVHG0uwkLl+axYo5aWQlT/C6TJGgUyBIRGjv6uGfO2t4emslz2+vorm9mwljRvGBWamszE/nrFmTSJgw2usyRTylQJCwdaizhxeLqnlyywFe3FFNW2cPiRNG88G56ayam87peSm6OZxIHwoECSvtXT28VFTNXzcf4IXt1Rzq6iEldgwXnpTBufMmsyw3idGaDxA5KgWChLyunl5eLq7hr5sO8Ow7lbR2+kLgopMz+ND8ySzLSdYVwiKDoECQkNTb69i45yCPv7WPp7Yc4GBbFwnjR3P+wimcv3AKy3KSdGaQyHHyJBDM7BLgm8AcYKlzrtCLOiT07K5p4bE39/H42/uoOHiI8aNHcU5+Gh9ZOIUzZqbqYfEiQ+BVD2Er8FHgVx7tX0JIQ1snf920n0ff3MemvQ1EGZyel8otK2eyMj+dmLHq6IoMB0/+T3LObQd061/pV0+v4+XiGh4prOC5bVV09vQyOz2OWz80h9WLpuieQSIBMOK/WpnZ9cD1AFlZWR5XI4G2t76NPxbu5ZHCCiqb2pk4YTRXLMvikoJM5k5J8Lo8kbAWsEAws+eB9KMsutU595fBbsc5twZYA1BQUOCGqTwZQTq7e3luWxUPvlHOq7vqiDI4c2Yqt5+fz/I5aZoXEAmSgAWCc25FoLYt4WFvfRsPvrGHRwr3UtvSSUbieL5wzkwuPjmTKYnjvS5PJOKM+CEjCS+9vY5/7qzhgXXlvFhUjQHL56RxxbIszshL1fUCIh7y6rTTC4H/AVKBJ83sbefcB72oRYKj8VAXjxTu5YF15ZTXtZEaN5Ybz5rBZUuz1BsQGSG8OsvoMeAxL/YtwVVa28p9r5by6MYK2jp7KJg2kVtWzmLV3HTNDYiMMBoykmHnnGNdST2/frmEF4qqGR0VxfkLp3DNadnMy9CZQiIjlQJBhk13Ty9Pba3k7rUlbNnXSHLMGG48O4+Pn5LFpDhdNyAy0ikQZMjau3p4pHAvv1pbQsXBQ+SmxPDdC+fz0cUZur20SAhRIMgJa27v4oF15dz7Sim1LZ2clJXIN87LZ8WcND14XiQEKRDkuDW2dXHfa6Xc92oZjYe6OGNmKp/7wHSW5STpdiQiIUyBIIPWeKiLe14p5b5XSmnu6GZlfho3nD2DBZmJXpcmIsNAgSADam7v4t5Xyvj1KyU0t3dz7rx0blqex5zJ8V6XJiLDSIEg/Wrv6uGB18v5xUu7ONjWxTn5ady8Yib5UxQEIuFIgSD/pqfX8ejGvfzkuWIqm9p5f14KX1w5i4VTNTQkEs4UCPIu5xwv7Kjme3/fQXF1C4umJvLTyxZxSm6y16WJSBAoEASAbfub+M6T23htdx05KTH835WLWTUvXWcNiUQQBUKEq23p4MfPFvHwhr0kjB/Ntz4ylyuWZTFaD6gXiTgKhAjV1dPLA6+X85Pnd3Kos4drT8vhprPzSJgw2uvSRMQjCoQItL6kjq//ZSs7q1p4f14Kt58/lxmTYr0uS0Q8pkCIIHUtHfz333fw6MYKMhLH86v/OJmV+WmaJxARQIEQEZxzPLqxgjue2k5Lezef/cB0bjo7j/FjdOM5EfkXBUKY21PXxtce28Iru2pZkj2ROy6cz8y0OK/LEpERSIEQpnp7Hb95rYwfPlPEqCjj2xfM48qlWboLqYj0S4EQhvbUtfGlRzexvrSes2alcseF8/XcYhEZkAIhjDjneHjDXr79t21EmfGDixdwycmZmjQWkUFRIISJ+tZOvvKnzTy7rYrTZiTzg4sXkqFegYgcBwVCGHiluJYv/PFtGtq6uO3Dc7j2tBzNFYjIcVMghLDunl5+9o9i7npxF9NTY7nvmiXMnZLgdVkiEqIUCCGqqqmdmx56i/Wl9VxycibfWj2XCWP06xSRE6cjSAhaX1LH/3vwTVo7erjz0oV8dHGm1yWJSBhQIIQQ53zXFtzx5Haykibw0KdOIU8XmYnIMFEghIj2rh6+9uct/PmtfayYk8adH1tI/DjdmVREho8CIQTUNHfw6QcKeXNPAzevmMmNZ8/QWUQiMuwUCCPc9gNNXHd/IXWtHfzflYs5d/5kr0sSkTClQBjBXi6u4TMPbCR2XDSPfPp9zM/UKaUiEjgKhBHq8bf28cVHNjFjUiy/uWYp6QnjvC5JRMKcAmEEWrN2N999agen5Cax5qoCTR6LSFAoEEYQ5xw/fKaIX7y0mw8vmMydly5kbLQeYiMiwaFAGCGcc3zrr9v4zWtlXLEsi++snqcziUQkqKK82KmZ/dDMdpjZZjN7zMwSvahjpOjpdXz1z1v4zWtlfPL0HO64QGEgIsHnSSAAzwHznHMLgJ3AVz2qw3O9vY6v/nkzD2/Yy41nz+C2D8/R8wtExBOeBIJz7lnnXLf/4zogIm/G45zjG09s5Y+FFdx09gxuWTlLYSAinvGqh9DXtcDf+1toZtebWaGZFdbU1ASxrMByzvFff9vG79bt4TNnTufmc2Z6XZKIRLiATSqb2fNA+lEW3eqc+4t/nVuBbuD3/W3HObcGWANQUFDgAlCqJ37y3E7ue7WMa0/L4cur1DMQEe8FLBCccyuOtdzMrgbOA5Y758LmQD8YD6wr5+cv7OLSgky+fp7mDERkZPDktFMzWwX8J3Cmc67Nixq88vTWA3zjL1tZPnsS371wvsJAREYMr+YQ7gLigOfM7G0z+6VHdQTVhrJ6bnr4bU6amshdVywmetRImMIREfHxpIfgnJvhxX69tLe+jU8/sJHMxPHc84kljB+jK5BFZGTRV9QgaOno5rr7C+nu6eXXnyhgYswYr0sSEfk3unVFgPX0Oj7/0Fvsqmnh/muWkpsa63VJIiJHpR5CgP3kuZ38Y0c1t5+fz+l5KV6XIyLSLwVCAL1YVM1dL+7iYwVTuerUbK/LERE5JgVCgOxvOMQX/vA2s9Pj+NbquV6XIyIyIAVCAHT19HLDg2/S2d3LL65czLjROqNIREa+AQPBzG40s4nBKCZc/OjZIt7c08D3LlqgSWQRCRmD6SGkARvM7I9mtsp0ae0xvVFaz5q1JVy+NIvzF07xuhwRkUEbMBCcc7cBecA9wNVAsZl918ymB7i2kNPS0c0tj7zN1IkTuO3Dc7wuR0TkuAxqDsF/87lK/6sbmAg8amY/CGBtIeeOJ7dTcfAQP750ITFjdYmHiISWAY9aZvZ54CqgFvg18CXnXJeZRQHF+G5SF/Fe3FHNQ2/s4dNn5LIkO8nrckREjttgvsYmAR91zpX3/aFzrtfMzgtMWaGltaObrz22hZlpsXrQjYiErAEDwTl3+zGWbR/eckLTT5/fyYHGdu664lSdYioiIUvXIQzR9gNN3PtqGZctmcrJ0zRUJCKhS4EwBL29jtse30rC+NF8edVsr8sRERkSBcIQPLqxgo3lB/nKubN1S2sRCXkKhBPU3N7F95/ewZLsiVy8ONPrckREhkyBcILuXltCXWsnXz8vn6goXbwtIqFPgXACqpvbufvlUs5bMJkFmYlelyMiMiwUCCfgZ88X09XTyxdXzvK6FBGRYaNAOE4lNS08vGEvVyzLIjslxutyRESGjQLhOP3o2SLGRUdx0/I8r0sRERlWCoTjUFTZzFNbKvnk6TmkxI71uhwRkWGlQDgOv/rnbiaMGcU1p+V4XYqIyLBTIAxSxcE2/rJpP5cvzdJFaCISlhQIg/Trl0sx4JOnq3cgIuFJgTAIdS0dPLxhDxeclMGUxPFelyMiEhAKhEG4//Vy2rt6+cyZuV6XIiISMAqEARzq7OH+18pYmZ/GjElxXpcjIhIwCoQB/G3zfhoPdXGt5g5EJMwpEAbw8Ia95KbGsCxHD78RkfCmQDiGnVXNbCw/yOVLsjDTHU1FJLwpEI7hoTf2MHqU8dHFGV6XIiIScAqEfrR39fDYW/v44Nx0knWbChGJAJ4Egpl928w2m9nbZvasmU3xoo5jeeadShraurh8aZbXpYiIBIVXPYQfOucWOOcWAX8DvuFRHf16cP0epiVP4NTcZK9LEREJCk8CwTnX1OdjDOC8qKM/5XWtrC+t52NLpurxmCISMaK92rGZ3QFcBTQCZx1jveuB6wGysoIzfPPUlkoAVi/SZLKIRI6A9RDM7Hkz23qU12oA59ytzrmpwO+BG/rbjnNujXOuwDlXkJqaGqhy3+PpdypZmJlAhu5bJCIRJGA9BOfcikGu+nvgKeD2QNVyPPY3HGLT3ga+vGq216WIiASVV2cZ9X3+5Gpghxd1HM3TW33DRavmpXtciYhIcHk1h/A9M5sF9ALlwGc8quPfPL21ktnpceSkxHhdiohIUHkSCM65i7zY70Cqm9vZUF7P55fnDbyyiEiY0ZXKfTz7ThXOwbnzJntdiohI0CkQ+nh6ayW5KTHMTIv1uhQRkaBTIPg1tHXyekkdH5yXrjubikhEUiD4rS2upafXsTI/zetSREQ8oUDwW1dSR9zYaOZnJHhdioiIJxQIfut217E0J4noUforEZHIpKMfUNXUTkltK6fozqYiEsEUCPiGiwAFgohENAUC/vmDcdHkT4n3uhQREc8oEIB1JfUsy0lilJ59ICIRLOIDobKxnVLNH4iIKBDWl2r+QEQEFAi8vruO+HHRzJms+QMRiWwRHwjrSupYmpOs+QMRiXgRHQgHGg9RVtfGKblJXpciIuK5iA6EN0rrAc0fiIhAhAfCtgNNjBkVxaz0OK9LERHxXEQHwo4DzUyfFMto3b9IRCSyA6Gospk56h2IiAARHAgHWzupbGrXcJGIiF/EBsKOymYAZuv6AxERIIIDoaiyCUBDRiIifhEbCDsqm5k4YTSpcWO9LkVEZESI6ECYnR6Pma5QFhGBCA2E3l7HzqpmTSiLiPQRkYGw92AbbZ09zJmsQBAROSwiA2H7Af8ZRuk6w0hE5LCIDISiymbMYGaaeggiIodFZCDsqGwiOzmG8WNGeV2KiMiIEaGB0Mws9Q5ERN4j4gLhUGcPZXWtzNaEsojIe0RcIOysasY5TSiLiBwp4gKh6PA9jHQNgojIe0RcIJTVtRIdZUxNmuB1KSIiI4qngWBmt5iZM7OUYO2zsqmdtPhxjIrSLStERPryLBDMbCqwEtgTzP1WNbUzKV43tBMROZKXPYSfAP8JuGDutKqpg/T4ccHcpYhISPAkEMxsNbDPObdpEOteb2aFZlZYU1Mz5H1XNfqGjERE5L2iA7VhM3seSD/KoluBr+EbLhqQc24NsAagoKBgSL2J1o5umju6FQgiIkcRsEBwzq042s/NbD6QA2zyP4sgE3jTzJY65yoDVQ/45g8A0hM0hyAicqSABUJ/nHNbgEmHP5tZGVDgnKsN9L4r/YGgHoKIyL+LqOsQqhQIIiL9CnoP4UjOuexg7auqqQNAZxmJiBxFRPUQKhvbiRsbTcxYz3NQRGTEiahA0EVpIiL9i7hASE/QcJGIyNFEWCB0aEJZRKQfERMIvb2OqiZdpSwi0p+ICYT6tk66e53OMBIR6UfEBEJlo65BEBE5logJhH9dlKazjEREjiaCAsF/UZrOMhIROaqICYTKpnbMIDVWPQQRkaOJmECoamwnJXYs0aMipskiIsclYo6OVc3tOsNIROQYIiYQKvWkNBGRY4qYQPBdlKb5AxGR/kREILR39XCwrUtDRiIixxARgVDT7DvlNE2nnIqI9CsiAkGPzhQRGVhkBIL/thUaMhIR6V9EBMLh21YoEERE+hcxgTA2Oor48Xp0pohIfyIiEKanxnLBogzMzOtSRERGrIj4ynzZ0iwuW5rldRkiIiNaRPQQRERkYAoEEREBFAgiIuKnQBAREUCBICIifgoEEREBFAgiIuKnQBAREQDMOed1DYNmZjVA+XH8kRSgNkDljGRqd2SJ1HZD5Lb9eNs9zTmXOtBKIRUIx8vMCp1zBV7XEWxqd2SJ1HZD5LY9UO3WkJGIiAAKBBER8Qv3QFjjdQEeUbsjS6S2GyK37QFpd1jPIYiIyOCFew9BREQGSYEgIiJAmASCma0ysyIz22VmXznK8rFm9gf/8vVmlh38KoffINr9BTPbZmabzewfZjbNizqH20Dt7rPeRWbmzCwsTkscTLvN7FL/7/wdM3sw2DUGwiD+nWeZ2Ytm9pb/3/qHvKhzuJnZvWZWbWZb+1luZvZz/9/LZjNbPOSdOudC+gWMAnYDucAYYBOQf8Q6nwN+6X9/GfAHr+sOUrvPAib43382UtrtXy8OWAusAwq8rjtIv+884C1gov/zJK/rDlK71wCf9b/PB8q8rnuY2n4GsBjY2s/yDwF/Bww4BVg/1H2GQw9hKbDLOVfinOsEHgZWH7HOauB+//tHgeUW+g9YHrDdzrkXnXNt/o/rgMwg1xgIg/l9A3wb+D7QHsziAmgw7f4U8L/OuYMAzrnqINcYCINptwPi/e8TgP1BrC9gnHNrgfpjrLIa+K3zWQckmtnkoewzHAIhA9jb53OF/2dHXcc51w00AslBqS5wBtPuvj6J79tEqBuw3f6u81Tn3JPBLCzABvP7ngnMNLNXzWydma0KWnWBM5h2fxP4uJlVAE8BNwanNM8d7zFgQNFDKkdCgpl9HCgAzvS6lkAzsyjgTuBqj0vxQjS+YaMP4OsNrjWz+c65Bk+rCrzLgd84535sZqcCD5jZPOdcr9eFhZpw6CHsA6b2+Zzp/9lR1zGzaHzdyrqgVBc4g2k3ZrYCuBX4iHOuI0i1BdJA7Y4D5gEvmVkZvrHVJ8JgYnkwv+8K4AnnXJdzrhTYiS8gQtlg2v1J4I8AzrnXgXH4bv4W7gZ1DDge4RAIG4A8M8sxszH4Jo2fOGKdJ4BP+N9fDLzg/LMyIWzAdpvZScCv8IVBOIwnwwDtds41OudSnHPZzrlsfHMnH3HOFXpT7rAZzL/zx/H1DjCzFHxDSCXBLDIABtPuPcByADObgy8QaoJapTeeAK7yn210CtDonDswlA2G/JCRc67bzG4AnsF3RsK9zrl3zOy/gELn3BPAPfi6kbvwTdJc5l3Fw2OQ7f4hEAs84p9D3+Oc+4hnRQ+DQbY77Ayy3c8AK81sG9ADfMk5F9I94UG2+xbgbjO7Gd8E89Vh8IUPM3sIX8Cn+OdHbgdGAzjnfolvvuRDwC6gDbhmyPsMg783EREZBuEwZCQiIsPq2B6dAAAA2UlEQVRAgSAiIoACQURE/BQIIiICKBBERMRPgSAiIoACQURE/BQIIkNgZkv896IfZ2Yx/ucQzPO6LpEToQvTRIbIzL6D73YJ44EK59x/e1ySyAlRIIgMkf8eOxvwPXvhfc65Ho9LEjkhGjISGbpkfPeMisPXUxAJSeohiAyRmT2B70leOcBk59wNHpckckJC/m6nIl4ys6uALufcg2Y2CnjNzM52zr3gdW0ix0s9BBERATSHICIifgoEEREBFAgiIuKnQBAREUCBICIifgoEEREBFAgiIuL3/wHQ0c3EEJFLIgAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "<Figure size 432x288 with 0 Axes>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "text/plain": [
       "<Figure size 432x288 with 0 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "x = np.arange(0.01,1,0.01)\n",
    "y = np.log(x)\n",
    "plt.title(\"y=log(x)\") \n",
    "plt.xlabel(\"x\") \n",
    "plt.ylabel(\"y\") \n",
    "plt.plot(x,y)\n",
    "plt.show()\n",
    "plt.figure()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "如自然对数的图形所示，当x等于1时，y为0；随着x向0靠近，y逐渐变小。因此，正确解标签对应的输出越大，交叉熵的值越接近0；当输出为1时，交叉熵误差为0。反之，如果正确解标签对应的输出越小，则交叉熵的值越大。\n",
    "\n",
    "在手写数字识别任务中，如果在现有代码中将模型的损失函数替换成交叉熵（cross_entropy），仅改动三行代码即可：在读取数据部分，将标签的类型设置成int，体现它是一个标签而不是实数值（飞桨框架默认将标签处理成int64）；在网络定义部分，将输出层改成“输出十个标签的概率”的模式；以及在训练过程部分，将损失函数从均方误差换成交叉熵。\n",
    "\n",
    "- 数据处理部分：label = np.reshape(labels[i], [1]).astype('int64')\n",
    "- 网络定义部分：self.fc = FC(name_scope, size=10, act='softmax')\n",
    "- 训练过程部分：loss = fluid.layers.cross_entropy(predict, label)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "如下是在数据处理部分，修改标签变量Label的格式。\n",
    "- 从：label = np.reshape(labels[i], [1]).astype('float32')\n",
    "- 到：label = np.reshape(labels[i], [1]).astype('int64')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "#修改标签数据的格式，从float32到int64\n",
    "import os\n",
    "import random\n",
    "import paddle\n",
    "import paddle.fluid as fluid\n",
    "from paddle.fluid.dygraph.nn import Conv2D, Pool2D, FC\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "\n",
    "import gzip\n",
    "import json\n",
    "\n",
    "# 定义数据集读取器\n",
    "def load_data(mode='train'):\n",
    "\n",
    "    # 数据文件\n",
    "    datafile = './work/mnist.json.gz'\n",
    "    print('loading mnist dataset from {} ......'.format(datafile))\n",
    "    data = json.load(gzip.open(datafile))\n",
    "    train_set, val_set, eval_set = data\n",
    "\n",
    "    # 数据集相关参数，图片高度IMG_ROWS, 图片宽度IMG_COLS\n",
    "    IMG_ROWS = 28\n",
    "    IMG_COLS = 28\n",
    "\n",
    "    if mode == 'train':\n",
    "        imgs = train_set[0]\n",
    "        labels = train_set[1]\n",
    "    elif mode == 'valid':\n",
    "        imgs = val_set[0]\n",
    "        labels = val_set[1]\n",
    "    elif mode == 'eval':\n",
    "        imgs = eval_set[0]\n",
    "        labels = eval_set[1]\n",
    "\n",
    "    imgs_length = len(imgs)\n",
    "\n",
    "    assert len(imgs) == len(labels), \\\n",
    "          \"length of train_imgs({}) should be the same as train_labels({})\".format(\n",
    "                  len(imgs), len(labels))\n",
    "\n",
    "    index_list = list(range(imgs_length))\n",
    "\n",
    "    # 读入数据时用到的batchsize\n",
    "    BATCHSIZE = 100\n",
    "\n",
    "    # 定义数据生成器\n",
    "    def data_generator():\n",
    "        if mode == 'train':\n",
    "            random.shuffle(index_list)\n",
    "        imgs_list = []\n",
    "        labels_list = []\n",
    "        for i in index_list:\n",
    "            img = np.reshape(imgs[i], [1, IMG_ROWS, IMG_COLS]).astype('float32')\n",
    "            label = np.reshape(labels[i], [1]).astype('int64')\n",
    "            imgs_list.append(img) \n",
    "            labels_list.append(label)\n",
    "            if len(imgs_list) == BATCHSIZE:\n",
    "                yield np.array(imgs_list), np.array(labels_list)\n",
    "                imgs_list = []\n",
    "                labels_list = []\n",
    "\n",
    "        # 如果剩余数据的数目小于BATCHSIZE，\n",
    "        # 则剩余数据一起构成一个大小为len(imgs_list)的mini-batch\n",
    "        if len(imgs_list) > 0:\n",
    "            yield np.array(imgs_list), np.array(labels_list)\n",
    "\n",
    "    return data_generator"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "如下是在网络定义部分，修改输出层结构。\n",
    "- 从：self.fc = FC(name_scope, size=1, act=None)\n",
    "- 到：self.fc = FC(name_scope, size=10, act='softmax')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# 定义模型结构\n",
    "class MNIST(fluid.dygraph.Layer):\n",
    "     def __init__(self, name_scope):\n",
    "         super(MNIST, self).__init__(name_scope)\n",
    "         name_scope = self.full_name()\n",
    "         # 定义一个卷积层，使用relu激活函数\n",
    "         self.conv1 = Conv2D(name_scope, num_filters=20, filter_size=5, stride=1, padding=2, act='relu')\n",
    "         # 定义一个池化层，池化核为2，步长为2，使用最大池化方式\n",
    "         self.pool1 = Pool2D(name_scope, pool_size=2, pool_stride=2, pool_type='max')\n",
    "         # 定义一个卷积层，使用relu激活函数\n",
    "         self.conv2 = Conv2D(name_scope, num_filters=20, filter_size=5, stride=1, padding=2, act='relu')\n",
    "         # 定义一个池化层，池化核为2，步长为2，使用最大池化方式\n",
    "         self.pool2 = Pool2D(name_scope, pool_size=2, pool_stride=2, pool_type='max')\n",
    "         # 定义一个全连接层，输出节点数为10 \n",
    "         self.fc = FC(name_scope, size=10, act='softmax')\n",
    "    # 定义网络的前向计算过程\n",
    "     def forward(self, inputs):\n",
    "         x = self.conv1(inputs)\n",
    "         x = self.pool1(x)\n",
    "         x = self.conv2(x)\n",
    "         x = self.pool2(x)\n",
    "         x = self.fc(x)\n",
    "         return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "如下代码仅修改计算损失的函数，从均方误差（常用于回归问题）到交叉熵误差（常用于分类问题）。\n",
    "- 从：loss = fluid.layers.square_error_cost(predict, label)\n",
    "- 到：loss = fluid.layers.cross_entropy(predict, label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loading mnist dataset from ./work/mnist.json.gz ......\n",
      "epoch: 0, batch: 0, loss is: [2.4464767]\n",
      "epoch: 0, batch: 200, loss is: [0.35378253]\n",
      "epoch: 0, batch: 400, loss is: [0.19394642]\n",
      "epoch: 1, batch: 0, loss is: [0.2654959]\n",
      "epoch: 1, batch: 200, loss is: [0.26667175]\n",
      "epoch: 1, batch: 400, loss is: [0.22529164]\n",
      "epoch: 2, batch: 0, loss is: [0.15588592]\n",
      "epoch: 2, batch: 200, loss is: [0.20306593]\n",
      "epoch: 2, batch: 400, loss is: [0.25867027]\n",
      "epoch: 3, batch: 0, loss is: [0.2143673]\n",
      "epoch: 3, batch: 200, loss is: [0.10330768]\n",
      "epoch: 3, batch: 400, loss is: [0.16994831]\n",
      "epoch: 4, batch: 0, loss is: [0.19997822]\n",
      "epoch: 4, batch: 200, loss is: [0.05874193]\n",
      "epoch: 4, batch: 400, loss is: [0.1538875]\n"
     ]
    }
   ],
   "source": [
    "#仅修改计算损失的函数，从均方误差（常用于回归问题）到交叉熵误差（常用于分类问题）\n",
    "with fluid.dygraph.guard():\n",
    "    model = MNIST(\"mnist\")\n",
    "    model.train()\n",
    "    #调用加载数据的函数\n",
    "    train_loader = load_data('train')\n",
    "    optimizer = fluid.optimizer.SGDOptimizer(learning_rate=0.01)\n",
    "    EPOCH_NUM = 5\n",
    "    for epoch_id in range(EPOCH_NUM):\n",
    "        for batch_id, data in enumerate(train_loader()):\n",
    "            #准备数据，变得更加简洁\n",
    "            image_data, label_data = data\n",
    "            image = fluid.dygraph.to_variable(image_data)\n",
    "            label = fluid.dygraph.to_variable(label_data)\n",
    "            \n",
    "            #前向计算的过程\n",
    "            predict = model(image)\n",
    "            \n",
    "            #计算损失，使用交叉熵损失函数，取一个批次样本损失的平均值\n",
    "            loss = fluid.layers.cross_entropy(predict, label)\n",
    "            avg_loss = fluid.layers.mean(loss)\n",
    "            \n",
    "            #每训练了200批次的数据，打印下当前Loss的情况\n",
    "            if batch_id % 200 == 0:\n",
    "                print(\"epoch: {}, batch: {}, loss is: {}\".format(epoch_id, batch_id, avg_loss.numpy()))\n",
    "            \n",
    "            #后向传播，更新参数的过程\n",
    "            avg_loss.backward()\n",
    "            optimizer.minimize(avg_loss)\n",
    "            model.clear_gradients()\n",
    "\n",
    "    #保存模型参数\n",
    "    fluid.save_dygraph(model.state_dict(), 'mnist')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "虽然上述训练过程的损失明显比使用均方误差算法要小，但因为损失函数量纲的变化，我们无法从比较两个不同的Loss得出谁更加优秀。怎么解决这个问题呢？我们可以回归到问题的直接衡量，谁的分类准确率高来判断。在后面介绍完计算准确率和作图的内容后，读者可以自行测试采用不同损失函数下，模型准确率的高低。\n",
    "\n",
    "因为我们修改了模型的输出格式，所以使用模型做预测时的代码也需要做相应的调整。从模型输出10个标签的概率中选择最大的，将其标签编号输出。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "本次预测的数字是:  0\n"
     ]
    }
   ],
   "source": [
    "# 读取一张本地的样例图片，转变成模型输入的格式\n",
    "def load_image(img_path):\n",
    "    # 从img_path中读取图像，并转为灰度图\n",
    "    im = Image.open(img_path).convert('L')\n",
    "    im.show()\n",
    "    im = im.resize((28, 28), Image.ANTIALIAS)\n",
    "    im = np.array(im).reshape(1, 1, 28, 28).astype(np.float32)\n",
    "    # 图像归一化\n",
    "    im = 1.0 - im / 255.\n",
    "    return im\n",
    "\n",
    "# 定义预测过程\n",
    "with fluid.dygraph.guard():\n",
    "    model = MNIST(\"mnist\")\n",
    "    params_file_path = 'mnist'\n",
    "    img_path = './work/example_0.jpg'\n",
    "    # 加载模型参数\n",
    "    model_dict, _ = fluid.load_dygraph(\"mnist\")\n",
    "    model.load_dict(model_dict)\n",
    "    \n",
    "    model.eval()\n",
    "    tensor_img = load_image(img_path)\n",
    "    #模型反馈10个分类标签的对应概率\n",
    "    results = model(fluid.dygraph.to_variable(tensor_img))\n",
    "    #取概率最大的标签作为预测输出\n",
    "    lab = np.argsort(results.numpy())\n",
    "    print(\"本次预测的数字是: \", lab[0][-1])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "PaddlePaddle 1.6.0 (Python 3.5)",
   "language": "python",
   "name": "py35-paddle1.2.0"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
